import os
import yaml
from omegaconf import OmegaConf
import torch

from saicinpainting.training.trainers import load_checkpoint

def get_mmt_model(model_dir='big-lama', checkpoint_name='best.ckpt', map_location='cpu'):
    """Load the inpainting model from ``model_dir`` and freeze it for inference.

    Reads the training config from ``<model_dir>/config.yaml`` and sets
    ``training_model.predict_only = True`` and ``visualizer.kind = 'noop'``.
    It then restores weights from ``<model_dir>/models/<checkpoint_name>`` via
    ``load_checkpoint`` (with ``strict=False``) and calls ``freeze()`` on the
    result.

    Args:
        model_dir: Directory containing ``config.yaml`` and a ``models/``
            subdirectory. The default ``'big-lama'`` is resolved relative to
            the current working directory, matching the previous hard-coded
            behaviour.
        checkpoint_name: File name of the checkpoint inside ``models/``.
        map_location: Forwarded unchanged to ``load_checkpoint``. The default
            ``'cpu'`` matches the previous behaviour.

    Returns:
        The object returned by ``load_checkpoint``, after ``freeze()`` has been
        called on it.

    Raises:
        FileNotFoundError: If ``<model_dir>/config.yaml`` does not exist.
    """
    train_config_path = os.path.join(model_dir, 'config.yaml')
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding.
    with open(train_config_path, 'r', encoding='utf-8') as f:
        train_config = OmegaConf.create(yaml.safe_load(f))

    # NOTE(review): presumably these switch off training-only components
    # (losses/discriminator, visualisation) so that only the inference path is
    # built. Confirm against saicinpainting's trainer code.
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(model_dir, 'models', checkpoint_name)
    # strict=False: presumably tolerates missing/unexpected state-dict keys
    # (e.g. training-only modules absent in predict-only mode). Verify with
    # load_checkpoint.
    model = load_checkpoint(train_config, checkpoint_path, map_location=map_location, strict=False)
    # NOTE(review): if this is a PyTorch Lightning module, freeze() puts it in
    # eval mode and disables gradients. Confirm the model's base class.
    model.freeze()
    return model